#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 15:54:38 2022

@author: qiguangyao
"""


#%%Lib
import copy
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns 
from scipy import asarray as ar,exp
from scipy.optimize import curve_fit
import math
import pingouin as pg
from sklearn import linear_model
from pylab import cos
import pandas as pd
import random
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.sandbox.stats.multicomp import multipletests # for multiple comparisons correction
from statsmodels.stats.multicomp import pairwise_tukeyhsd
print("__file Output:",__file__)
#%%functions
import scipy.stats
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and its two-sided Student-t confidence bounds.

    Parameters
    ----------
    data : sequence of numbers
        Samples; converted to a float NumPy array.
    confidence : float
        Confidence level of the interval (default 0.95).

    Returns
    -------
    tuple
        ``(mean, lower, upper)`` where ``lower``/``upper`` are the mean minus /
        plus the interval half-width.
    """
    samples = 1.0 * np.array(data)
    dof = len(samples) - 1  # degrees of freedom of the t distribution
    center = np.mean(samples)
    std_err = scipy.stats.sem(samples)
    # Half-width = standard error times the upper t quantile.
    half_width = std_err * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    return center, center - half_width, center + half_width

def adjust_spines(ax, spines):
    """Keep only the requested spines of *ax* and offset them outward.

    Parameters
    ----------
    ax : axes-like object
        Must expose ``spines`` (mapping name -> spine), ``xaxis`` and ``yaxis``.
    spines : container of str
        Names of the spines to keep (e.g. ``['left', 'bottom']``).

    Spines that are kept are moved outward by 10 points; all others are made
    invisible by setting their colour to ``'none'``.  Ticks are removed on an
    axis whose 'left' / 'bottom' spine is not kept.
    """
    for side, line in ax.spines.items():
        if side not in spines:
            line.set_color('none')  # hide spines that were not requested
            continue
        line.set_position(('outward', 10))  # push kept spine out by 10 points

    # y ticks live on the left spine; drop them if that spine is hidden
    if 'left' not in spines:
        ax.yaxis.set_ticks([])
    else:
        ax.yaxis.set_ticks_position('left')

    # x ticks live on the bottom spine; drop them if that spine is hidden
    if 'bottom' not in spines:
        ax.xaxis.set_ticks([])
    else:
        ax.xaxis.set_ticks_position('bottom')
        
def gaus(x,a,x0,sigma):
    """Scaled normal density: ``a`` times the N(x0, sigma**2) pdf evaluated at x.

    Parameters
    ----------
    x : float or array-like
        Point(s) at which to evaluate the curve.
    a : float
        Scale factor (area under the curve).
    x0 : float
        Centre of the Gaussian.
    sigma : float
        Standard deviation (must be non-zero).

    Notes
    -----
    The previous expression ``1/sigma*np.sqrt(2*np.pi)`` evaluated to
    ``sqrt(2*pi)/sigma`` because of operator precedence; the normalisation is
    now ``1/(sigma*sqrt(2*pi))``.  ``np.exp`` replaces the deprecated
    ``scipy.exp`` alias.
    """
    norm = 1.0 / (sigma * np.sqrt(2 * np.pi))
    return a * norm * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

def gaussian(X, amp, cen, wid):
    """Un-normalised Gaussian bump ``amp * exp(-(X - cen)**2 / wid)``.

    Parameters
    ----------
    X : float or array-like
        Point(s) at which to evaluate the curve.
    amp : float
        Peak height, reached at ``X == cen``.
    cen : float
        Centre of the bump.
    wid : float
        Denominator of the exponent (equals ``2*sigma**2`` for a Gaussian with
        standard deviation ``sigma``).

    Uses ``np.exp`` rather than the deprecated ``scipy.exp`` alias.
    """
    return amp * np.exp(-(X - cen) ** 2 / wid)

def getPossionPDF(mu,x):
    """Poisson probability of observing count ``x`` for rate ``mu + 0.01``.

    Parameters
    ----------
    mu : float
        Rate parameter; 0.01 is added before evaluation (as in the original
        implementation).
    x : float
        Observed count.  It is clamped to the range [0, 170] and rounded to
        the nearest integer before evaluation.

    Returns
    -------
    float
        ``exp(-rate) * rate**k / k!`` with ``rate = mu + 0.01`` and ``k`` the
        clamped, rounded count.  If ``rate <= 0`` the distribution is treated
        as degenerate at zero: 1.0 for ``k == 0`` and 0.0 otherwise.

    Notes
    -----
    The probability is evaluated in log space so that large rates no longer
    raise ``OverflowError`` from ``rate**k``.
    """
    # Clamp the count exactly as before: upper bound first, then lower bound.
    if x > 170:
        x = 170
    if x < 0:
        x = 0
    count = int(round(x))

    rate = mu + 0.01
    if rate <= 0:
        # A non-positive rate has no valid Poisson mass away from zero.
        return 1.0 if count == 0 else 0.0

    # log P = k*log(rate) - rate - log(k!)
    log_prob = count * math.log(rate) - rate - math.lgamma(count + 1)
    return math.exp(log_prob)

#tuning curve fitting
def vonMisesFunction(x,b,a,u):
    """Rectified cosine tuning curve ``max(b + a*cos(x - u), 0)``.

    Despite the name, the model evaluated here is a cosine with an offset,
    with negative values clipped to zero.

    Parameters
    ----------
    x : float or array-like
        Angle(s) in radians; lists are accepted and converted to an array.
    b : float
        Baseline (offset) of the curve.
    a : float
        Amplitude of the cosine modulation.
    u : float
        Phase, i.e. the angle at which the cosine term peaks (radians).

    Returns
    -------
    numpy.ndarray
        Curve values with the same shape as ``x`` (0-d array for a scalar).
    """
    # np.asarray lets plain Python lists be used for x; np.cos replaces the
    # pylab re-export of the same function.
    out = b + a * np.cos(np.asarray(x) - u)
    out = np.array(out)
    out[out < 0] = 0  # clip negative values to zero
    return out

def getvonMisesParas(x,y):
    """
    Fit ``vonMisesFunction`` to the samples (x, y) with ``curve_fit``.

    x:hand position
    y:firing rate

    Returns the optimal parameter vector in the order [b, a, u]; the
    covariance estimate produced by ``curve_fit`` is discarded.
    """
    startGuess = [1, 0, 1]  # initial [b, a, u]
    fitResult = curve_fit(vonMisesFunction, x, y, p0=startGuess, maxfev=500000)
    return fitResult[0]

def getExpParas(x,y):
    """
    Fit ``expFunction`` (a * exp(-b * x) + c) to the samples (x, y).

    x:independent variable samples
    y:values to fit

    Returns the optimal parameter vector in the order [a, b, c]; the
    covariance estimate produced by ``curve_fit`` is discarded.
    """
    startGuess = [1, 0, 1]  # initial [a, b, c]
    fitResult = curve_fit(expFunction, x, y, p0=startGuess, maxfev=500000)
    return fitResult[0]

def expFunction(x, a, b, c):
    """Exponential curve ``a * exp(-b * x) + c``.

    x may be a scalar or a NumPy array; a is the scale of the exponential
    term, b its rate, and c the constant offset.
    """
    scaled = a * np.exp(-b * x)
    return scaled + c



#%% ------------figure 6---------------- 
# Load the figure-6 data bundle.  The file is opened in a ``with`` block so
# the handle is closed as soon as the pickle has been read.
# NOTE: unpickling can execute arbitrary code; only load a fig6Data.pickle
# that comes from a trusted source.
with open('fig6Data.pickle','rb') as fig6File:
    fig6Data = pkl.load(fig6File)

#6A
delaPEV = fig6Data['delaPEV']
vpHandFRMean = fig6Data['vpHandFRMean']
vpcHandZeroDispFRMean = fig6Data['vpcHandZeroDispFRMean']
vpHandFRSEM = fig6Data['vpHandFRSEM']
vpcHandZeroDispFRSEM = fig6Data['vpcHandZeroDispFRSEM']

#6B
PMCPEVFBinsVP = fig6Data['PMCPEVFBinsVP']
PMCPEVFBinsVPC = fig6Data['PMCPEVFBinsVPC']
Area5PEVFBinsVP = fig6Data['Area5PEVFBinsVP']
Area5PEVFBinsVPC = fig6Data['Area5PEVFBinsVPC']

#6C
PMCVPHand = fig6Data['PMCVPHand']
PMCVPHandLow = fig6Data['PMCVPHandLow']
PMCVPHandUp = fig6Data['PMCVPHandUp']
PMCVPCHand = fig6Data['PMCVPCHand']
PMCVPCHandLow = fig6Data['PMCVPCHandLow']
PMCVPCHandUp = fig6Data['PMCVPCHandUp']

Area5VPHand = fig6Data['Area5VPHand']
Area5VPHandLow = fig6Data['Area5VPHandLow']
Area5VPHandUp = fig6Data['Area5VPHandUp']
Area5VPCHand = fig6Data['Area5VPCHand']
Area5VPCHandLow = fig6Data['Area5VPCHandLow']
Area5VPCHandUp = fig6Data['Area5VPCHandUp']


#%%fig6A left
# Figure 6A (left): time course of delaPEV with one bin highlighted.
keySele = '20180620N1SPK01a'
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    plt.figure(figsize = [3.54/2,3.54/2])

    # 22 time bins, 0.1 apart, starting at -0.8 (x axis is labelled in seconds).
    fig6ALeftTimes = [-.8+k*.1 for k in range(22)]
    plt.plot(fig6ALeftTimes,delaPEV,lw = 1,color = 'k')

    # Highlight bin 13 with a large dot and a dashed vertical guide that runs
    # from the bottom of the axis (-.07) up to the data point.
    fig6ALeftMarkTime = -.8+13*.1
    plt.scatter(fig6ALeftMarkTime, delaPEV[13],s = 200, color = 'gray')
    plt.plot([fig6ALeftMarkTime]*22, np.linspace(-.07,delaPEV[13],22),'--',lw = 1,color = 'gray')

    plt.ylim(bottom = -.07)
    plt.ylabel('Delta ωPEV')
    plt.xlabel('Time from target onset (s)')
    # The x limits and layout are set twice; the second pair of calls below
    # determines the final appearance.  Call order is kept as-is.
    plt.xlim([-1,1])
    plt.tight_layout()
    plt.xticks(np.arange(-1, 1.51, step=.5))
    plt.xlim([-.9,1.4])
    plt.tight_layout()
    fileName = 'fig6A_Left_Area5'+'DeltaωPEV'+keySele+'.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig6A right
# Figure 6A (right): firing rate at time bin 13 for each target direction,
# with a fitted tuning curve per condition (VP and VPC).
keySele = '20180620N1SPK01a'
targetDirections = np.array([-30, -20,   0,  20,  30])
# The fit works in radians; the plot below is drawn in degrees.
targetDirectionsRad = [math.radians(deg) for deg in targetDirections]
# Dense grid (radians) on which both fitted curves are evaluated.
rangeDire = np.linspace(-.53, .53, 40)

# VP condition: fit, then evaluate the fitted curve on the grid.
parasVP = getvonMisesParas(targetDirectionsRad,vpHandFRMean[:,13])
vpHandFRMeanFitting = [
    vonMisesFunction(rad,parasVP[0],parasVP[1],parasVP[2]) for rad in rangeDire
]

# VPC condition: same procedure.
parasVPC = getvonMisesParas(targetDirectionsRad,vpcHandZeroDispFRMean[:,13])
vpcHandFRMeanFitting = [
    vonMisesFunction(rad,parasVPC[0],parasVPC[1],parasVPC[2]) for rad in rangeDire
]

rangeDireDeg = [math.degrees(rad) for rad in rangeDire]
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    fig, ax = plt.subplots(figsize = [3.54/2,3.54/2])
    # Data points are shifted by -/+ 0.5 deg so the two sets of error bars
    # do not overlap.
    ax.plot(rangeDireDeg,vpHandFRMeanFitting,'-',color = colors[0],lw = 1)
    ax.errorbar(
        targetDirections-.5, vpHandFRMean[:,13], yerr=vpHandFRSEM[:,13],
        fmt='.',color = colors[0],label = 'VP',lw = 1, elinewidth = .8,
    )
    ax.plot(rangeDireDeg,vpcHandFRMeanFitting,'-',color = 'k',lw = 1)
    ax.errorbar(
        targetDirections+.5, vpcHandZeroDispFRMean[:,13], yerr=vpcHandZeroDispFRSEM[:,13],
        fmt='.',color = 'k',label = 'VPC (0°)',lw = 1, elinewidth = .8,
    )
    plt.xlabel('Hand location (deg)')
    plt.ylabel('Firing rate (Hz)')
    plt.legend(loc = 'upper center',bbox_to_anchor=[0.5,1.1])
    adjust_spines(ax, ['left', 'bottom'])
    plt.tight_layout()
    fileName = 'fig6A_Right_Area5'+'tuningCurve'+keySele+'.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()

#%%fig6B
# Figure 6B: NaN-ignoring mean (over axis 0) of the PEV traces, VP vs VPC,
# for the 'Premotor' (left) and 'Parietal' (right) panels.
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    # One time axis shared by every trace: 22 bins, 0.1 apart, starting at -0.8.
    fig6BTimes = [i/10-.8 for i in range(22)]
    f, (ax1, ax2) = plt.subplots(ncols=2, nrows=1, sharex=True,sharey=True,figsize=[3.54,3.54/2])#raw

    # Premotor panel
    ax1.plot(fig6BTimes,np.nanmean(PMCPEVFBinsVP,axis = 0),
             color=colors[0],linewidth = 1,label = 'VP')
    ax1.plot(fig6BTimes,np.nanmean(PMCPEVFBinsVPC,axis = 0),'-',
             color='k',linewidth = 1,label = 'VPC (0°)')
    ax1.legend(loc = 'upper left')
    ax1.set_title('Premotor',fontsize = 10)
    ax1.set_xticks(np.arange(-1, 1.51, step=.5))
    ax1.set_yticks(np.arange(-0.05, 0.251, step=0.1))
    ax1.set_ylim([-.05,0.15])
    ax1.set_xlim([-.9,1.4])

    #Area5
    # Horizontal bar at y = .1 spanning bins 12..17
    # (presumably marks a window of interest -- confirm with the author).
    ax2.plot(fig6BTimes[12:18], [.1]*6 ,'-', color = 'k',linewidth = 1)
    ax2.plot(fig6BTimes,np.nanmean(Area5PEVFBinsVP,axis = 0),
             color=colors[0],linewidth = 1,label = 'VP')
    ax2.plot(fig6BTimes,np.nanmean(Area5PEVFBinsVPC,axis = 0),'-',
             color='k',linewidth = 1,label = 'VPC (0°)')
    ax2.set_title('Parietal',fontsize = 10)
    ax1.set_ylabel('ωPEV')
    ax1.set_xlabel('Time from target onset (s)')
    plt.tight_layout()
    fileName = 'fig6B_zeroDispωPEVMeanDynamicsHand.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig6C
# Figure 6C: decoding accuracy over time, VP vs VPC, for the 'Premotor'
# (left) and 'Parietal' (right) panels.  Each curve is the NaN-ignoring mean
# over rows 0..49 of the corresponding array, with a shaded band between the
# stored Low/Up bounds.
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    # One time axis shared by every trace: 22 bins, 0.1 apart, starting at -0.8.
    fig6CTimes = [i/10-.8 for i in range(22)]
    # Dashed reference line at 0.2 (presumably chance level -- confirm).
    fig6CRefLevel = [0.2]*22
    f, (ax1, ax2) = plt.subplots(ncols=2, nrows=1, sharex=True,sharey=True,figsize=[3.54,3.54/2])
    #PMC
    ax1.plot(fig6CTimes, np.nanmean(PMCVPHand[range(50)],axis = 0),'-',
             color = colors[0],linewidth = 1,label ='VP' )
    ax1.fill_between(fig6CTimes,PMCVPHandLow, PMCVPHandUp,
                     edgecolor=colors[0], facecolor=colors[0],alpha=0.3)
    ax1.plot(fig6CTimes, np.nanmean(PMCVPCHand[range(50)],axis = 0),'-',
             color = 'k',linewidth = 1,label = 'VPC (0°)')
    ax1.fill_between(fig6CTimes,PMCVPCHandLow, PMCVPCHandUp,
                     edgecolor='k', facecolor='k',alpha=0.3)
    ax1.plot(fig6CTimes, fig6CRefLevel,'--', color = 'k',linewidth = 1)
    ax1.set_ylabel('Decoding accuracy')
    ax1.set_xlabel('Time from target onset (s)')
    ax1.legend(loc = 'upper left')
    ax1.set_title('Premotor',fontsize = 10)
    ax1.set_xticks(np.arange(-1, 1.51, step=.5))
    ax1.set_yticks(np.arange(0, .7, step=.2))
    ax1.set_ylim([.05,.7])
    ax1.set_xlim([-.9,1.4])
    #Area5
    ax2.plot(fig6CTimes, np.nanmean(Area5VPHand[range(50)],axis = 0),'-',
             color = colors[0],linewidth = 1,label ='VP' )
    ax2.fill_between(fig6CTimes,Area5VPHandLow, Area5VPHandUp,
                     edgecolor=colors[0], facecolor=colors[0],alpha=0.3)
    ax2.plot(fig6CTimes, np.nanmean(Area5VPCHand[range(50)],axis = 0),'-',
             color = 'k',linewidth = 1,label = 'VPC (0°)')
    ax2.fill_between(fig6CTimes,Area5VPCHandLow, Area5VPCHandUp,
                     edgecolor='k', facecolor='k',alpha=0.3)
    # Horizontal bar at y = 0.65 spanning bins 13..19
    # (presumably marks a window of interest -- confirm with the author).
    ax2.plot(fig6CTimes[13:20],[0.65]*7,'-', color = 'k',linewidth = 1)
    ax2.set_title('Parietal',fontsize = 10)
    ax2.plot(fig6CTimes, fig6CRefLevel,'--', color = 'k',linewidth = 1)

    plt.tight_layout()
    fileName = 'fig6C_zeroDispDecodingTargetHand.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()